{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eCes7jVU8r08"
      },
      "source": [
        "##### Copyright 2023 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pc1j3ZVF8mmG"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SUUX9CnCYI9Y"
      },
      "source": [
        "# Instance Segmentation with Model Garden\n",
        "\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/vision/instance_segmentation\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/vision/instance_segmentation.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/vision/instance_segmentation.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/vision/instance_segmentation.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UjP7bQUdTeFr"
      },
      "source": [
        "This tutorial fine-tunes a [Mask R-CNN](https://arxiv.org/abs/1703.06870) with [Mobilenet V2](https://arxiv.org/abs/1801.04381) as backbone model from the [TensorFlow Model Garden](https://pypi.org/project/tf-models-official/) package (tensorflow-models).\n",
        "\n",
        "\n",
        "[Model Garden](https://www.tensorflow.org/tfmodels) contains a collection of state-of-the-art models, implemented with TensorFlow's high-level APIs. The implementations demonstrate the best practices for modeling, letting users take full advantage of TensorFlow for their research and product development.\n",
        "\n",
        "This tutorial demonstrates how to:\n",
        "\n",
        "1. Use models from the TensorFlow Models package.\n",
        "2. Train/Fine-tune a pre-built Mask R-CNN with mobilenet as backbone for Object Detection and Instance Segmentation\n",
        "3. Export the trained/tuned Mask R-CNN model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RDp6Kk1Baoi4"
      },
      "source": [
        "## Install Necessary Dependencies"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hcl98qUOxlL8"
      },
      "outputs": [],
      "source": [
        "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
        "%pip install -U -q \"tf-models-official\"\n",
        "%pip install -U -q remotezip tqdm opencv-python einops"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5-gCe_YTapey"
      },
      "source": [
        "## Import required libraries"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Qa9552Ukgf3d"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import io\n",
        "import json\n",
        "import tqdm\n",
        "import shutil\n",
        "import pprint\n",
        "import pathlib\n",
        "import tempfile\n",
        "import requests\n",
        "import collections\n",
        "import matplotlib\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "from PIL import Image\n",
        "from six import BytesIO\n",
        "from etils import epath\n",
        "from IPython import display\n",
        "from urllib.request import urlopen"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tSCMIDRDP2fV"
      },
      "outputs": [],
      "source": [
        "import orbit\n",
        "import tensorflow as tf\n",
        "import tensorflow_models as tfm\n",
        "import tensorflow_datasets as tfds\n",
        "\n",
        "from official.core import exp_factory\n",
        "from official.core import config_definitions as cfg\n",
        "from official.vision.data import tfrecord_lib\n",
        "from official.vision.serving import export_saved_model_lib\n",
        "from official.vision.dataloaders.tf_example_decoder import TfExampleDecoder\n",
        "from official.vision.utils.object_detection import visualization_utils\n",
        "from official.vision.ops.preprocess_ops import normalize_image, resize_and_crop_image\n",
        "from official.vision.data.create_coco_tf_record import coco_annotations_to_lists\n",
        "\n",
        "# Pretty-printer used later to dump the experiment config (4-space indent).\n",
        "pp = pprint.PrettyPrinter(indent=4)\n",
        "# Log the TensorFlow version so results can be tied to a release.\n",
        "print(tf.__version__)\n",
        "\n",
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GIrXW8sp2bKa"
      },
      "source": [
        "## Download a subset of the LVIS dataset\n",
        "\n",
        "[LVIS](https://www.tensorflow.org/datasets/catalog/lvis): A dataset for large vocabulary instance segmentation.\n",
        "\n",
        "Note: LVIS uses the COCO 2017 train, validation, and test image sets. \n",
        "If you have already downloaded the COCO images, you only need to download \n",
        "the LVIS annotations. LVIS val set contains images from COCO 2017 train in \n",
        "addition to the COCO 2017 val split."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "F_A9_cS310jf"
      },
      "outputs": [],
      "source": [
        "# @title Download annotation files\n",
        "\n",
        "# `unzip -o` overwrites existing files without prompting, so re-running this\n",
        "# cell does not block on an interactive 'replace?' question.\n",
        "!wget https://dl.fbaipublicfiles.com/LVIS/lvis_v1_train.json.zip\n",
        "!unzip -o -q lvis_v1_train.json.zip\n",
        "!rm lvis_v1_train.json.zip\n",
        "\n",
        "!wget https://dl.fbaipublicfiles.com/LVIS/lvis_v1_val.json.zip\n",
        "!unzip -o -q lvis_v1_val.json.zip\n",
        "!rm lvis_v1_val.json.zip\n",
        "\n",
        "!wget https://dl.fbaipublicfiles.com/LVIS/lvis_v1_image_info_test_dev.json.zip\n",
        "!unzip -o -q lvis_v1_image_info_test_dev.json.zip\n",
        "!rm lvis_v1_image_info_test_dev.json.zip"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "kB-C5Svj11S0"
      },
      "outputs": [],
      "source": [
        "# @title Lvis annotation parsing\n",
        "\n",
        "# Annotations with invalid bounding boxes. Will not be used.\n",
        "_INVALID_ANNOTATIONS = [\n",
        "    # Train split.\n",
        "    662101,\n",
        "    81217,\n",
        "    462924,\n",
        "    227817,\n",
        "    29381,\n",
        "    601484,\n",
        "    412185,\n",
        "    504667,\n",
        "    572573,\n",
        "    91937,\n",
        "    239022,\n",
        "    181534,\n",
        "    101685,\n",
        "    # Validation split.\n",
        "    36668,\n",
        "    57541,\n",
        "    33126,\n",
        "    10932,\n",
        "]\n",
        "\n",
        "def get_category_map(annotation_path, num_classes):\n",
        "  \"\"\"Map 1-based contiguous labels to the first `num_classes` LVIS categories.\n",
        "\n",
        "  Args:\n",
        "    annotation_path: Path to an LVIS annotation JSON file.\n",
        "    num_classes: How many categories to keep, in file order.\n",
        "\n",
        "  Returns:\n",
        "    Dict of {label: {'id': category id, 'name': category name}} where the\n",
        "    labels count up from 1.\n",
        "  \"\"\"\n",
        "  with epath.Path(annotation_path).open() as f:\n",
        "    annotations = json.load(f)\n",
        "\n",
        "  selected = annotations['categories'][:num_classes]\n",
        "  return {\n",
        "      label: {'id': category['id'], 'name': category['name']}\n",
        "      for label, category in enumerate(selected, start=1)\n",
        "  }\n",
        "\n",
        "class LvisAnnotation:\n",
        "  \"\"\"LVIS annotation helper class.\n",
        "\n",
        "  The format of the annotations is explained on\n",
        "  https://www.lvisdataset.org/dataset.\n",
        "\n",
        "  Only annotations whose `category_id` is in the allowed set are indexed,\n",
        "  and `images` lists only the images that have at least one of them.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self, annotation_path, allowed_category_ids=None):\n",
        "    \"\"\"Loads the annotation file and indexes kept annotations by image id.\n",
        "\n",
        "    Args:\n",
        "      annotation_path: Path to an LVIS annotation JSON file.\n",
        "      allowed_category_ids: Iterable of category ids to keep. When None,\n",
        "        falls back to the notebook-level `category_ids` (set in the\n",
        "        configuration cell), so existing calls behave the same.\n",
        "    \"\"\"\n",
        "    with epath.Path(annotation_path).open() as f:\n",
        "      data = json.load(f)\n",
        "    self._data = data\n",
        "\n",
        "    if allowed_category_ids is None:\n",
        "      allowed_category_ids = category_ids  # notebook global\n",
        "    # A set makes the per-annotation membership test O(1).\n",
        "    allowed = set(allowed_category_ids)\n",
        "\n",
        "    img_id2annotations = collections.defaultdict(list)\n",
        "    for a in self._data.get('annotations', []):\n",
        "      if a['category_id'] in allowed:\n",
        "        img_id2annotations[a['image_id']].append(a)\n",
        "    # Sort each image's annotations by annotation id for a deterministic order.\n",
        "    self._img_id2annotations = {\n",
        "        k: sorted(v, key=lambda a: a['id'])\n",
        "        for k, v in img_id2annotations.items()\n",
        "    }\n",
        "\n",
        "  @property\n",
        "  def categories(self):\n",
        "    \"\"\"Return the category dicts, as sorted in the file.\"\"\"\n",
        "    return self._data['categories']\n",
        "\n",
        "  @property\n",
        "  def images(self):\n",
        "    \"\"\"Return the image dicts (in file order) that have kept annotations.\"\"\"\n",
        "    return [\n",
        "        image_info for image_info in self._data['images']\n",
        "        if image_info['id'] in self._img_id2annotations\n",
        "    ]\n",
        "\n",
        "  def get_annotations(self, img_id):\n",
        "    \"\"\"Return the kept annotations for `img_id`, sorted by annotation id.\n",
        "\n",
        "    Images without kept annotations yield an empty list.\n",
        "    \"\"\"\n",
        "    return self._img_id2annotations.get(img_id, [])\n",
        "\n",
        "def _generate_tf_records(prefix, images_zip, annotation_file, num_shards=5):\n",
        "  \"\"\"Downloads the annotated LVIS images and writes them as sharded TFRecords.\n",
        "\n",
        "  Every image returned by `LvisAnnotation.images` is fetched from its\n",
        "  `coco_url` into `IMGS_DIR` (skipped when the file already exists), turned\n",
        "  into a `tf.train.Example`, and written round-robin into `num_shards` files\n",
        "  named `<tf_records_dir><prefix>-XXXXX-of-XXXXX.tfrecord`.\n",
        "\n",
        "  Args:\n",
        "    prefix: Split name used as the TFRecord filename prefix (e.g. 'train').\n",
        "    images_zip: Unused; kept so existing calls keep working. Images are\n",
        "      downloaded one by one from their `coco_url` instead.\n",
        "    annotation_file: Path to the LVIS annotation JSON for the split.\n",
        "    num_shards: Number of TFRecord shards to write.\n",
        "\n",
        "  Raises:\n",
        "    requests.HTTPError: If an image download returns an error status.\n",
        "  \"\"\"\n",
        "  lvis_annotation = LvisAnnotation(annotation_file)\n",
        "\n",
        "  def _process_example(prefix, image_info, id_to_name_map):\n",
        "    \"\"\"Converts one downloaded image and its kept annotations to a tf.train.Example.\"\"\"\n",
        "    filename = pathlib.Path(image_info['coco_url']).name\n",
        "    image = tf.io.read_file(os.path.join(IMGS_DIR, filename))\n",
        "    instances = lvis_annotation.get_annotations(img_id=image_info['id'])\n",
        "    instances = [x for x in instances if x['id'] not in _INVALID_ANNOTATIONS]\n",
        "    # Mark every instance as non-crowd before the COCO-style conversion.\n",
        "    instances = [dict(x, iscrowd=0) for x in instances]\n",
        "    data, _ = coco_annotations_to_lists(instances,\n",
        "                                        id_to_name_map,\n",
        "                                        image_info['height'],\n",
        "                                        image_info['width'],\n",
        "                                        include_masks=True)\n",
        "    keys_to_features = {\n",
        "        'image/encoded':\n",
        "            tfrecord_lib.convert_to_feature(image.numpy()),\n",
        "        'image/filename':\n",
        "            tfrecord_lib.convert_to_feature(filename.encode('utf8')),\n",
        "        'image/format':\n",
        "            tfrecord_lib.convert_to_feature('jpg'.encode('utf8')),\n",
        "        'image/height':\n",
        "            tfrecord_lib.convert_to_feature(image_info['height']),\n",
        "        'image/width':\n",
        "            tfrecord_lib.convert_to_feature(image_info['width']),\n",
        "        'image/source_id':\n",
        "            tfrecord_lib.convert_to_feature(str(image_info['id']).encode('utf8')),\n",
        "        'image/object/bbox/xmin':\n",
        "            tfrecord_lib.convert_to_feature(data['xmin']),\n",
        "        'image/object/bbox/xmax':\n",
        "            tfrecord_lib.convert_to_feature(data['xmax']),\n",
        "        'image/object/bbox/ymin':\n",
        "            tfrecord_lib.convert_to_feature(data['ymin']),\n",
        "        'image/object/bbox/ymax':\n",
        "            tfrecord_lib.convert_to_feature(data['ymax']),\n",
        "        'image/object/class/text':\n",
        "            tfrecord_lib.convert_to_feature(data['category_names']),\n",
        "        'image/object/class/label':\n",
        "            tfrecord_lib.convert_to_feature(data['category_id']),\n",
        "        'image/object/is_crowd':\n",
        "            tfrecord_lib.convert_to_feature(data['is_crowd']),\n",
        "        'image/object/area':\n",
        "            tfrecord_lib.convert_to_feature(data['area'], 'float_list'),\n",
        "        'image/object/mask':\n",
        "            tfrecord_lib.convert_to_feature(data['encoded_mask_png'])\n",
        "    }\n",
        "    return tf.train.Example(\n",
        "        features=tf.train.Features(feature=keys_to_features))\n",
        "\n",
        "  # Only the first NUM_CLASSES categories are kept (NUM_CLASSES is set in the config cell).\n",
        "  id_to_name_map = {cat_dict['id']: cat_dict['name']\n",
        "                    for cat_dict in lvis_annotation.categories[:NUM_CLASSES]}\n",
        "  writers = [\n",
        "      tf.io.TFRecordWriter(\n",
        "          tf_records_dir + prefix + '-%05d-of-%05d.tfrecord' % (i, num_shards))\n",
        "      for i in range(num_shards)\n",
        "  ]\n",
        "  try:\n",
        "    for idx, image_info in enumerate(tqdm.tqdm(lvis_annotation.images)):\n",
        "      img_path = os.path.join(IMGS_DIR, pathlib.Path(image_info['coco_url']).name)\n",
        "      # Reuse images saved by an earlier run instead of downloading them again.\n",
        "      if not os.path.exists(img_path):\n",
        "        response = requests.get(image_info['coco_url'], timeout=60)  # seconds\n",
        "        # Fail loudly rather than saving an HTTP error page as a .jpg.\n",
        "        response.raise_for_status()\n",
        "        # Write to a temporary name and rename, so an interrupted run never\n",
        "        # leaves a truncated image that the cache check above would reuse.\n",
        "        tmp_path = img_path + '.part'\n",
        "        with open(tmp_path, 'wb') as handler:\n",
        "          handler.write(response.content)\n",
        "        os.replace(tmp_path, img_path)\n",
        "      tf_example = _process_example(prefix, image_info, id_to_name_map)\n",
        "      writers[idx % num_shards].write(tf_example.SerializeToString())\n",
        "  finally:\n",
        "    # Close every shard so buffered records are flushed, even when a\n",
        "    # download or conversion fails part-way through.\n",
        "    for writer in writers:\n",
        "      writer.close()\n",
        "\n",
        "  del lvis_annotation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5u2dwjIT2HZu"
      },
      "outputs": [],
      "source": [
        "# COCO 2017 image archives.\n",
        "# NOTE(review): these are passed to _generate_tf_records as `images_zip`,\n",
        "# which does not use them; images are fetched per file from each `coco_url`.\n",
        "_URLS = {\n",
        "    'train_images': 'http://images.cocodataset.org/zips/train2017.zip',\n",
        "    'validation_images': 'http://images.cocodataset.org/zips/val2017.zip',\n",
        "    'test_images': 'http://images.cocodataset.org/zips/test2017.zip',\n",
        "}\n",
        "\n",
        "train_prefix = 'train'\n",
        "valid_prefix = 'val'\n",
        "\n",
        "train_annotation_path = './lvis_v1_train.json'\n",
        "valid_annotation_path = './lvis_v1_val.json'\n",
        "\n",
        "# Downloaded images and generated TFRecord shards. Both keep a trailing slash\n",
        "# because shard file names are built by string concatenation.\n",
        "IMGS_DIR = './lvis_sub_dataset/'\n",
        "tf_records_dir = './lvis_tfrecords/'\n",
        "\n",
        "# exist_ok makes this cell safe to re-run.\n",
        "os.makedirs(IMGS_DIR, exist_ok=True)\n",
        "os.makedirs(tf_records_dir, exist_ok=True)\n",
        "\n",
        "# Use only the first NUM_CLASSES LVIS categories (a small subset).\n",
        "NUM_CLASSES = 3\n",
        "category_index = get_category_map(valid_annotation_path, NUM_CLASSES)\n",
        "category_ids = list(category_index.keys())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KBgl5fG42LpD"
      },
      "outputs": [],
      "source": [
        "# Build the training TFRecords. The LVIS parsing helpers used here are adapted\n",
        "# from the TFDS LVIS dataset builder:\n",
        "# https://github.com/tensorflow/datasets/blob/master/tensorflow_datasets/datasets/lvis/lvis_dataset_builder.py\n",
        "_generate_tf_records(\n",
        "    prefix=train_prefix,\n",
        "    images_zip=_URLS['train_images'],\n",
        "    annotation_file=train_annotation_path,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "89O59u_H2NIJ"
      },
      "outputs": [],
      "source": [
        "# Build the validation TFRecords with the same helper.\n",
        "_generate_tf_records(\n",
        "    prefix=valid_prefix,\n",
        "    images_zip=_URLS['validation_images'],\n",
        "    annotation_file=valid_annotation_path,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EREyevfIY4rz"
      },
      "source": [
        "## Configure the Mask R-CNN MobileNet COCO model for a custom dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5yGLLvXlPInP"
      },
      "outputs": [],
      "source": [
        "train_data_input_path = './lvis_tfrecords/train*'\n",
        "valid_data_input_path = './lvis_tfrecords/val*'\n",
        "test_data_input_path = './lvis_tfrecords/test*'\n",
        "model_dir = './trained_model/'\n",
        "export_dir ='./exported_model/'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ms3wRQKAIORe"
      },
      "outputs": [],
      "source": [
        "# exist_ok makes the cell idempotent under Restart & Run All.\n",
        "os.makedirs(model_dir, exist_ok=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EXA5NmvDblYP"
      },
      "source": [
        "In Model Garden, the collections of parameters that define a model are called *configs*. Model Garden can create a config based on a known set of parameters via a [factory](https://en.wikipedia.org/wiki/Factory_method_pattern).\n",
        "\n",
        "\n",
        "Use the `maskrcnn_mobilenet_coco` experiment configuration, as defined by `tfm.vision.configs.maskrcnn.maskrcnn_mobilenet_coco`.\n",
        "\n",
        "Please find all the registered experiments [here](https://www.tensorflow.org/api_docs/python/tfm/core/exp_factory/get_exp_config).\n",
        "\n",
        "The configuration defines an experiment to train a Mask R-CNN model with MobileNet as the backbone and FPN as the decoder. The default configuration is trained on [COCO](https://cocodataset.org/) train2017 and evaluated on [COCO](https://cocodataset.org/) val2017.\n",
        "\n",
        "There are also other alternative experiments available such as\n",
        "`maskrcnn_resnetfpn_coco`,\n",
        "`maskrcnn_spinenet_coco` and more. One can switch to them by changing the experiment name argument to the `get_exp_config` function."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Zi2F1qGgPWOH"
      },
      "outputs": [],
      "source": [
        "# Start from the registered Mask R-CNN + MobileNet COCO experiment; it is customized below.\n",
        "exp_config = exp_factory.get_exp_config('maskrcnn_mobilenet_coco')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zo-EaCdmn5j-"
      },
      "outputs": [],
      "source": [
        "model_ckpt_path = './model_ckpt/'\n",
        "os.makedirs(model_ckpt_path, exist_ok=True)\n",
        "\n",
        "# Pretrained MobileNetV2 checkpoint (both files of the ckpt-180648 prefix),\n",
        "# loaded into the backbone by the task config below. IPython expands\n",
        "# {model_ckpt_path}, so the destination directory is defined only once.\n",
        "!gsutil cp gs://tf_model_garden/vision/mobilenet/v2_1.0_float/ckpt-180648.data-00000-of-00001 '{model_ckpt_path}'\n",
        "!gsutil cp gs://tf_model_garden/vision/mobilenet/v2_1.0_float/ckpt-180648.index '{model_ckpt_path}'"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ymnwJYaFgHs2"
      },
      "source": [
        "### Adjust the model and dataset configurations so that they work with the custom dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zyn9ieZyUbEJ"
      },
      "outputs": [],
      "source": [
        "# Global batch size and the fixed model input shape (height, width, 3 channels).\n",
        "BATCH_SIZE = 8\n",
        "HEIGHT, WIDTH = 256, 256\n",
        "IMG_SHAPE = [HEIGHT, WIDTH, 3]\n",
        "\n",
        "\n",
        "# Evaluation annotations: no COCO-format JSON file is supplied.\n",
        "# NOTE(review): presumably the evaluator then builds groundtruth from the\n",
        "# validation TFRecords instead - confirm with the Model Garden docs.\n",
        "exp_config.task.annotation_file = None\n",
        "# Backbone Config: load the downloaded MobileNetV2 checkpoint into the\n",
        "# backbone only, and keep the backbone frozen during fine-tuning.\n",
        "exp_config.task.freeze_backbone = True\n",
        "exp_config.task.init_checkpoint = \"./model_ckpt/ckpt-180648\"\n",
        "exp_config.task.init_checkpoint_modules = \"backbone\"\n",
        "\n",
        "# Model Config: get_category_map numbers classes from 1, so +1 leaves\n",
        "# label 0 free (presumably for the background class).\n",
        "exp_config.task.model.num_classes = NUM_CLASSES + 1\n",
        "exp_config.task.model.input_size = IMG_SHAPE\n",
        "\n",
        "# Training Data Config. aug_scale_min == aug_scale_max == 1.0 presumably\n",
        "# disables random scale augmentation.\n",
        "exp_config.task.train_data.input_path = train_data_input_path\n",
        "exp_config.task.train_data.dtype = 'float32'\n",
        "exp_config.task.train_data.global_batch_size = BATCH_SIZE\n",
        "exp_config.task.train_data.shuffle_buffer_size = 64\n",
        "exp_config.task.train_data.parser.aug_scale_max = 1.0\n",
        "exp_config.task.train_data.parser.aug_scale_min = 1.0\n",
        "\n",
        "# Validation Data Config\n",
        "exp_config.task.validation_data.input_path = valid_data_input_path\n",
        "exp_config.task.validation_data.dtype = 'float32'\n",
        "exp_config.task.validation_data.global_batch_size = BATCH_SIZE"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0409ReANgKzF"
      },
      "source": [
        "### Adjust the trainer configuration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ne8t5AHRUd9g"
      },
      "outputs": [],
      "source": [
        "# Detect the accelerator from the logical device names (also reused by the strategy cell).\n",
        "logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]\n",
        "\n",
        "if 'GPU' in ''.join(logical_device_names):\n",
        "  print('This may be broken in Colab.')\n",
        "  device = 'GPU'\n",
        "elif 'TPU' in ''.join(logical_device_names):\n",
        "  print('This may be broken in Colab.')\n",
        "  device = 'TPU'\n",
        "else:\n",
        "  # NOTE(review): this message promises fewer steps, but train_steps below is\n",
        "  # 2000 regardless of device - lower it manually when running on CPU.\n",
        "  print('Running on CPU is slow, so only train for a few steps.')\n",
        "  device = 'CPU'\n",
        "\n",
        "\n",
        "train_steps = 2000\n",
        "# steps_per_loop = num_of_training_examples // train_batch_size\n",
        "exp_config.trainer.steps_per_loop = 200\n",
        "\n",
        "# Write summaries, save checkpoints and run validation every 200 steps.\n",
        "exp_config.trainer.summary_interval = 200\n",
        "exp_config.trainer.checkpoint_interval = 200\n",
        "exp_config.trainer.validation_interval = 200\n",
        "# validation_steps = num_of_validation_examples // eval_batch_size\n",
        "exp_config.trainer.validation_steps =  200\n",
        "exp_config.trainer.train_steps = train_steps\n",
        "# Linear warmup for the first 200 steps, then cosine decay over train_steps.\n",
        "exp_config.trainer.optimizer_config.warmup.linear.warmup_steps = 200\n",
        "exp_config.trainer.optimizer_config.learning_rate.type = 'cosine'\n",
        "exp_config.trainer.optimizer_config.learning_rate.cosine.decay_steps = train_steps\n",
        "exp_config.trainer.optimizer_config.learning_rate.cosine.initial_learning_rate = 0.07\n",
        "exp_config.trainer.optimizer_config.warmup.linear.warmup_learning_rate = 0.05"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "k3I4X-bWgNm0"
      },
      "source": [
        "### Print the modified configuration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IsmxXNlyWBAK"
      },
      "outputs": [],
      "source": [
        "# Dump the full, modified experiment config.\n",
        "pp.pprint(exp_config.as_dict())\n",
        "# Colab-only JavaScript: set the output frame height to 500px for the long dump.\n",
        "# NOTE(review): google.colab is not defined outside Colab.\n",
        "display.Javascript(\"google.colab.output.setIframeHeight('500px');\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jxarWEHDgQSk"
      },
      "source": [
        "### Set up the distribution strategy."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4JxhiGNwQRv2"
      },
      "outputs": [],
      "source": [
        "# Enable mixed precision only when the experiment requests float16.\n",
        "if exp_config.runtime.mixed_precision_dtype == tf.float16:\n",
        "  tf.keras.mixed_precision.set_global_policy('mixed_float16')\n",
        "\n",
        "# Pick a strategy for the accelerator detected in the trainer-config cell.\n",
        "if 'GPU' in ''.join(logical_device_names):\n",
        "  distribution_strategy = tf.distribute.MirroredStrategy()\n",
        "elif 'TPU' in ''.join(logical_device_names):\n",
        "  tf.tpu.experimental.initialize_tpu_system()\n",
        "  tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='/device:TPU_SYSTEM:0')\n",
        "  # tf.distribute.TPUStrategy is the stable API; the experimental alias is deprecated.\n",
        "  distribution_strategy = tf.distribute.TPUStrategy(tpu)\n",
        "else:\n",
        "  print('Warning: this will be really slow.')\n",
        "  distribution_strategy = tf.distribute.OneDeviceStrategy(logical_device_names[0])\n",
        "\n",
        "print(\"Done\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QqZU9f1ugS_A"
      },
      "source": [
        "## Create the `Task` object (`tfm.core.base_task.Task`) from the `config_definitions.TaskConfig`.\n",
        "\n",
        "The `Task` object has all the methods necessary for building the dataset, building the model, and running training & evaluation. These methods are driven by `tfm.core.train_lib.run_experiment`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "N5R-7KzORB1n"
      },
      "outputs": [],
      "source": [
        "# Create the Task (dataset, model, training and evaluation logic) inside the\n",
        "# distribution strategy's scope; its logs go to model_dir.\n",
        "with distribution_strategy.scope():\n",
        "  task = tfm.core.task_factory.get_task(exp_config.task, logging_dir=model_dir)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Fmpz2R_cglIv"
      },
      "source": [
        "## Visualize a batch of the data."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "O82f_7A8gfnY"
      },
      "outputs": [],
      "source": [
        "# Pull one training batch and report the image tensor shape/dtype and the label keys.\n",
        "for images, labels in task.build_inputs(exp_config.task.train_data).take(1):\n",
        "  print()\n",
        "  print(f'images.shape: {str(images.shape):16}  images.dtype: {images.dtype!r}')\n",
        "  print(f'labels.keys: {labels.keys()}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dLcSHWjqgl66"
      },
      "source": [
        "### Create a TFExample decoder to read the TFRecords (labels are mapped to their corresponding names via `category_index`)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ajF85r_6R9d9"
      },
      "outputs": [],
      "source": [
        "# Decodes serialized tf.train.Examples; include_mask=True so instance masks are decoded too.\n",
        "tf_ex_decoder = TfExampleDecoder(include_mask=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gRdveeYVgr7B"
      },
      "source": [
        "### Helper Function for Visualizing the results from TFRecords\n",
        "Use `visualize_boxes_and_labels_on_image_array` from `visualization_utils` to draw bounding boxes on the image."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uWEuOs8QStrz"
      },
      "outputs": [],
      "source": [
        "def show_batch(raw_records):\n",
        "  plt.figure(figsize=(20, 20))\n",
        "  use_normalized_coordinates=True\n",
        "  min_score_thresh = 0.30\n",
        "  for i, serialized_example in enumerate(raw_records):\n",
        "    plt.subplot(1, 3, i + 1)\n",
        "    decoded_tensors = tf_ex_decoder.decode(serialized_example)\n",
        "    image = decoded_tensors['image'].numpy().astype('uint8')\n",
        "    scores = np.ones(shape=(len(decoded_tensors['groundtruth_boxes'])))\n",
        "    # print(decoded_tensors['groundtruth_instance_masks'].numpy().shape)\n",
        "    # print(decoded_tensors.keys())\n",
        "    visualization_utils.visualize_boxes_and_labels_on_image_array(\n",
        "        image,\n",
        "        decoded_tensors['groundtruth_boxes'].numpy(),\n",
        "        decoded_tensors['groundtruth_classes'].numpy().astype('int'),\n",
        "        scores,\n",
        "        category_index=category_index,\n",
        "        use_normalized_coordinates=use_normalized_coordinates,\n",
        "        min_score_thresh=min_score_thresh,\n",
        "        instance_masks=decoded_tensors['groundtruth_instance_masks'].numpy().astype('uint8'),\n",
        "        line_thickness=4)\n",
        "\n",
        "    plt.imshow(image)\n",
        "    plt.axis(\"off\")\n",
        "    plt.title(f\"Image-{i+1}\")\n",
        "  plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FergQ2P5gv_j"
      },
      "source": [
        "### Visualization of Train Data\n",
        "\n",
        "The bounding box detection has three components\n",
        "  1. Class label of the object detected.\n",
        "  2. Confidence score of the detection, shown as a percentage.\n",
        "  3. Instance Segmentation Mask\n",
        "\n",
        "**Note**: Every score is 100% because these are ground-truth boxes, which are all assigned a score of 1.0."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lN0zdBwxU5Z5"
      },
      "outputs": [],
      "source": [
        "# Sample a few training records and draw their ground truth.\n",
        "# NOTE(review): shuffle() has no seed, so the previewed images change on every run.\n",
        "buffer_size = 100\n",
        "num_of_examples = 3\n",
        "\n",
        "train_tfrecords = tf.io.gfile.glob(exp_config.task.train_data.input_path)\n",
        "raw_records = tf.data.TFRecordDataset(train_tfrecords).shuffle(buffer_size=buffer_size).take(num_of_examples)\n",
        "show_batch(raw_records)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nn7IZSs5hQLg"
      },
      "source": [
        "## Train and evaluate\n",
        "\n",
        "We follow the COCO challenge tradition to evaluate the accuracy of object detection based on mAP (mean Average Precision). See [here](https://cocodataset.org/#detection-eval) for a detailed explanation of how the detection evaluation metrics are computed.\n",
        "\n",
        "**IoU**: is defined as the area of the intersection divided by the area of the union of a predicted bounding box and ground truth bounding box."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UTuIs4kFZGv_"
      },
      "outputs": [],
      "source": [
        "# Train with periodic validation ('train_and_eval'), then run one final\n",
        "# evaluation (run_post_eval=True). Returns the trained Keras model and eval logs.\n",
        "model, eval_logs = tfm.core.train_lib.run_experiment(\n",
        "    distribution_strategy=distribution_strategy,\n",
        "    task=task,\n",
        "    mode='train_and_eval',\n",
        "    params=exp_config,\n",
        "    model_dir=model_dir,\n",
        "    run_post_eval=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rfpH4QHkh1gI"
      },
      "source": [
        "## Load logs in tensorboard"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wcdOvg6eNP6R"
      },
      "outputs": [],
      "source": [
        "%load_ext tensorboard\n",
        "%tensorboard --logdir \"./trained_model\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hAo9lozJh2cV"
      },
      "source": [
        "## Saving and exporting the trained model\n",
        "\n",
        "The `keras.Model` object returned by `train_lib.run_experiment` expects the data to be normalized by the dataset loader using the same mean and variance statiscics in `preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB)`. This export function handles those details, so you can pass `tf.uint8` images and get the correct results."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iZG1vPbTQqFh"
      },
      "outputs": [],
      "source": [
        "export_saved_model_lib.export_inference_graph(\n",
        "    input_type='image_tensor',\n",
        "    batch_size=1,\n",
        "    input_image_size=[HEIGHT, WIDTH],\n",
        "    params=exp_config,\n",
        "    checkpoint_path=tf.train.latest_checkpoint(model_dir),\n",
        "    export_dir=export_dir)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OHIfMeVXh7vJ"
      },
      "source": [
        "## Inference from Trained Model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uaXyzMvXROTd"
      },
      "outputs": [],
      "source": [
        "def load_image_into_numpy_array(path):\n",
        "  \"\"\"Load an image from file into a numpy array.\n",
        "\n",
        "  Puts image into numpy array to feed into tensorflow graph.\n",
        "  Note that by convention we put it into a numpy array with shape\n",
        "  (height, width, channels), where channels=3 for RGB.\n",
        "\n",
        "  Args:\n",
        "    path: the file path to the image\n",
        "\n",
        "  Returns:\n",
        "    uint8 numpy array with shape (img_height, img_width, 3)\n",
        "  \"\"\"\n",
        "  image = None\n",
        "  if(path.startswith('http')):\n",
        "    response = urlopen(path)\n",
        "    image_data = response.read()\n",
        "    image_data = BytesIO(image_data)\n",
        "    image = Image.open(image_data)\n",
        "  else:\n",
        "    image_data = tf.io.gfile.GFile(path, 'rb').read()\n",
        "    image = Image.open(BytesIO(image_data))\n",
        "\n",
        "  (im_width, im_height) = image.size\n",
        "  return np.array(image.getdata()).reshape(\n",
        "      (1, im_height, im_width, 3)).astype(np.uint8)\n",
        "\n",
        "\n",
        "\n",
        "def build_inputs_for_object_detection(image, input_image_size):\n",
        "  \"\"\"Builds Object Detection model inputs for serving.\"\"\"\n",
        "  image, _ = resize_and_crop_image(\n",
        "      image,\n",
        "      input_image_size,\n",
        "      padded_size=input_image_size,\n",
        "      aug_scale_min=1.0,\n",
        "      aug_scale_max=1.0)\n",
        "  return image"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZDI9zv_4h-7-"
      },
      "source": [
        "## Visualize test data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rdyIri-1RThk"
      },
      "outputs": [],
      "source": [
        "num_of_examples = 3\n",
        "\n",
        "test_tfrecords = tf.io.gfile.glob('./lvis_tfrecords/val*')\n",
        "test_ds = tf.data.TFRecordDataset(test_tfrecords).take(num_of_examples)\n",
        "show_batch(test_ds)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KkMZm4DtiAHO"
      },
      "source": [
        "## Importing SavedModel"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rDozz4NXRZ7p"
      },
      "outputs": [],
      "source": [
        "imported = tf.saved_model.load(export_dir)\n",
        "model_fn = imported.signatures['serving_default']"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DUxk4-AjLAcO"
      },
      "source": [
        "## Visualize predictions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Gez57T5ShYnM"
      },
      "outputs": [],
      "source": [
        "def reframe_image_corners_relative_to_boxes(boxes):\n",
        "  \"\"\"Reframe the image corners ([0, 0, 1, 1]) to be relative to boxes.\n",
        "  The local coordinate frame of each box is assumed to be relative to\n",
        "  its own for corners.\n",
        "  Args:\n",
        "    boxes: A float tensor of [num_boxes, 4] of (ymin, xmin, ymax, xmax)\n",
        "      coordinates in relative coordinate space of each bounding box.\n",
        "  Returns:\n",
        "    reframed_boxes: Reframes boxes with same shape as input.\n",
        "  \"\"\"\n",
        "  ymin, xmin, ymax, xmax = (boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3])\n",
        "\n",
        "  height = tf.maximum(ymax - ymin, 1e-4)\n",
        "  width = tf.maximum(xmax - xmin, 1e-4)\n",
        "\n",
        "  ymin_out = (0 - ymin) / height\n",
        "  xmin_out = (0 - xmin) / width\n",
        "  ymax_out = (1 - ymin) / height\n",
        "  xmax_out = (1 - xmin) / width\n",
        "  return tf.stack([ymin_out, xmin_out, ymax_out, xmax_out], axis=1)\n",
        "\n",
        "def reframe_box_masks_to_image_masks(box_masks, boxes, image_height,\n",
        "                                     image_width, resize_method='bilinear'):\n",
        "  \"\"\"Transforms the box masks back to full image masks.\n",
        "  Embeds masks in bounding boxes of larger masks whose shapes correspond to\n",
        "  image shape.\n",
        "  Args:\n",
        "    box_masks: A tensor of size [num_masks, mask_height, mask_width].\n",
        "    boxes: A tf.float32 tensor of size [num_masks, 4] containing the box\n",
        "           corners. Row i contains [ymin, xmin, ymax, xmax] of the box\n",
        "           corresponding to mask i. Note that the box corners are in\n",
        "           normalized coordinates.\n",
        "    image_height: Image height. The output mask will have the same height as\n",
        "                  the image height.\n",
        "    image_width: Image width. The output mask will have the same width as the\n",
        "                 image width.\n",
        "    resize_method: The resize method, either 'bilinear' or 'nearest'. Note that\n",
        "      'bilinear' is only respected if box_masks is a float.\n",
        "  Returns:\n",
        "    A tensor of size [num_masks, image_height, image_width] with the same dtype\n",
        "    as `box_masks`.\n",
        "  \"\"\"\n",
        "  resize_method = 'nearest' if box_masks.dtype == tf.uint8 else resize_method\n",
        "  # TODO(rathodv): Make this a public function.\n",
        "  def reframe_box_masks_to_image_masks_default():\n",
        "    \"\"\"The default function when there are more than 0 box masks.\"\"\"\n",
        "\n",
        "    num_boxes = tf.shape(box_masks)[0]\n",
        "    box_masks_expanded = tf.expand_dims(box_masks, axis=3)\n",
        "\n",
        "    resized_crops = tf.image.crop_and_resize(\n",
        "        image=box_masks_expanded,\n",
        "        boxes=reframe_image_corners_relative_to_boxes(boxes),\n",
        "        box_indices=tf.range(num_boxes),\n",
        "        crop_size=[image_height, image_width],\n",
        "        method=resize_method,\n",
        "        extrapolation_value=0)\n",
        "    return tf.cast(resized_crops, box_masks.dtype)\n",
        "\n",
        "  image_masks = tf.cond(\n",
        "      tf.shape(box_masks)[0] > 0,\n",
        "      reframe_box_masks_to_image_masks_default,\n",
        "      lambda: tf.zeros([0, image_height, image_width, 1], box_masks.dtype))\n",
        "  return tf.squeeze(image_masks, axis=3)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6EIRAlXcSQaA"
      },
      "outputs": [],
      "source": [
        "input_image_size = (HEIGHT, WIDTH)\n",
        "plt.figure(figsize=(20, 20))\n",
        "min_score_thresh = 0.40 # Change minimum score for threshold to see all bounding boxes confidences\n",
        "\n",
        "for i, serialized_example in enumerate(test_ds):\n",
        "  plt.subplot(1, 3, i+1)\n",
        "  decoded_tensors = tf_ex_decoder.decode(serialized_example)\n",
        "  image = build_inputs_for_object_detection(decoded_tensors['image'], input_image_size)\n",
        "  image = tf.expand_dims(image, axis=0)\n",
        "  image = tf.cast(image, dtype = tf.uint8)\n",
        "  image_np = image[0].numpy()\n",
        "  result = model_fn(image)\n",
        "  # Visualize detection and masks\n",
        "  if 'detection_masks' in result:\n",
        "    # we need to convert np.arrays to tensors\n",
        "    detection_masks = tf.convert_to_tensor(result['detection_masks'][0])\n",
        "    detection_boxes = tf.convert_to_tensor(result['detection_boxes'][0])\n",
        "    detection_masks_reframed = reframe_box_masks_to_image_masks(\n",
        "              detection_masks, detection_boxes/256.0,\n",
        "                image_np.shape[0], image_np.shape[1])\n",
        "    detection_masks_reframed = tf.cast(\n",
        "        detection_masks_reframed > min_score_thresh,\n",
        "        np.uint8)\n",
        "\n",
        "    result['detection_masks_reframed'] = detection_masks_reframed.numpy()\n",
        "  visualization_utils.visualize_boxes_and_labels_on_image_array(\n",
        "        image_np,\n",
        "        result['detection_boxes'][0].numpy(),\n",
        "        (result['detection_classes'][0] + 0).numpy().astype(int),\n",
        "        result['detection_scores'][0].numpy(),\n",
        "        category_index=category_index,\n",
        "        use_normalized_coordinates=False,\n",
        "        max_boxes_to_draw=200,\n",
        "        min_score_thresh=min_score_thresh,\n",
        "        instance_masks=result.get('detection_masks_reframed', None),\n",
        "        line_thickness=4)\n",
        "\n",
        "  plt.imshow(image_np)\n",
        "  plt.axis(\"off\")\n",
        "\n",
        "plt.show()"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "instance_segmentation.ipynb",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
